feat(magi): honor AttnMaskSpec on the HF attention backend #2622
Open
HuiyingLi wants to merge 2 commits into
The custom-model magi attn_func reads the active AttnMaskSpec (packing / sliding-window / prefix-tree masks via the flex key), but the HF-registered magi forward did not -- so attn_implementation="magi" silently dropped any non-causal mask while backend.attn="magi" applied it. Worse, a model whose attention dispatches on config._attn_implementation (e.g. the custom Qwen2) with backend.attn="magi" falls back to its default attention and drops the mask with no error.

Bring the HF forward to parity: it now reads the per-step AttnMaskSpec stamped on the attention module by _set_attn_spec_on_attention() and builds the flex key from it (cp_size==1; the mask rides on `module`, already in the HF attention signature, so no process-global is read inside the interface).

Add a consumption guard: a spec armed for a step but never read by a magi forward raises on the next step, turning the silent non-magi fallback into a loud error.

CPU unit tests cover the stamping + guard; the GPU forward is exercised by the FFA parity tests.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
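To make the stamping and the consumption guard concrete, here is a minimal pure-Python sketch of the idea. It is not the code in this PR: `SpecGuard`, `FakeAttention`, and `FakeMLP` are hypothetical stand-ins invented for illustration, and only the `_magi_attn_spec` attribute name comes from the description. A real implementation would walk `torch.nn.Module` instances rather than plain objects.

```python
from typing import Any, Iterable, Optional


class FakeAttention:
    """Stand-in for an attention module (a real one would be a torch.nn.Module)."""


class FakeMLP:
    """Stand-in for a non-attention module; it must never be stamped."""


class SpecGuard:
    """Hypothetical guard: tracks whether the stamped spec was read this step."""

    def __init__(self) -> None:
        self._pending = False  # True while a stamped spec is still unread

    def stamp(self, modules: Iterable[Any], spec: Optional[Any]) -> None:
        # A spec armed on the previous step but never read means attention
        # ran through some non-magi path and the mask was silently dropped.
        if self._pending:
            raise RuntimeError("AttnMaskSpec was stamped but never consumed")
        for module in modules:
            if isinstance(module, FakeAttention):
                module._magi_attn_spec = spec
        self._pending = spec is not None

    def read(self, module: Any) -> Optional[Any]:
        # Called from the attention forward; reading a spec disarms the guard.
        spec = getattr(module, "_magi_attn_spec", None)
        if spec is not None:
            self._pending = False
        return spec
```

Checking at the next stamping call, rather than inside the forward, is what lets the guard catch the case where the magi forward is never entered at all.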
Add CPU tests (magi_attention stubbed) that the registered "magi" HF forward builds the flex key and marks the spec consumed when _magi_attn_spec is on the module, and falls back to the dispatched key otherwise.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
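The branch those tests exercise can be sketched roughly as follows. This is an illustration under assumptions, not the PR's `magi_attention_forward`: `select_attn_key`, `build_flex_key`, and `mark_consumed` are hypothetical names, and the stubs below stand in for the real key builders. Only the `_magi_attn_spec` attribute is taken from the commit message.

```python
from types import SimpleNamespace
from typing import Any, Callable


def select_attn_key(
    module: Any,
    dispatched_key: Any,
    build_flex_key: Callable[[Any], Any],
    mark_consumed: Callable[[], None],
) -> Any:
    """Pick the attention key for one forward call (hypothetical helper)."""
    spec = getattr(module, "_magi_attn_spec", None)
    if spec is None:
        # No spec stamped: behave exactly as before and use the dispatched key.
        return dispatched_key
    # Spec present: build the key from it and record that it was honored.
    key = build_flex_key(spec)
    mark_consumed()
    return key


# Usage with stubs, in the spirit of a CPU test with magi_attention stubbed out.
consumed = []
stamped = SimpleNamespace(_magi_attn_spec="prefix-tree-spec")
plain = SimpleNamespace()

flex = select_attn_key(
    stamped, "dispatched", lambda s: ("flex", s), lambda: consumed.append(True)
)
fallback = select_attn_key(
    plain, "dispatched", lambda s: ("flex", s), lambda: consumed.append(True)
)
```

Reading the spec off `module` keeps the function free of global state, so the two branches can be tested with nothing more than attribute-bearing stubs.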
Contributor, Author
/claude review
Contributor, Author
/ok to test 81ef112
What
Brings the HF magi attention backend (`model.attn_implementation=magi`) to parity with the custom-model factory path (`model.backend.attn=magi`): it now honors an arbitrary `AttnMaskSpec` (sequence packing / sliding-window / prefix-tree masks), and fails loudly instead of silently dropping a mask.

Why
`make_magi_attn_func` (custom-model factory) already reads the active `AttnMaskSpec` and builds the FFA flex key from it. The HF-registered forward (`magi_attention_forward`, the `"magi"` entry in `ALL_ATTENTION_FUNCTIONS`) did not; it only built a plain causal / dispatched key. Consequences:
- `attn_implementation=magi` silently dropped any non-causal mask (packing / prefix-tree).
- A model whose attention dispatches on `config._attn_implementation` (e.g. the registered custom `Qwen2`) configured with `backend.attn=magi` falls back to its default attention (`flash_attention_2`) and drops the mask with no error; the magi forward is never even entered.

What changed
- `magi_attention_forward` reads the per-step `AttnMaskSpec` stamped on the attention module and builds the flex key from it (cp_size==1). The mask rides on `module` (already in the HF attention signature), so no process-global is read inside the interface.
- `_set_attn_spec_on_attention(model, spec)` stamps the spec on the language-backbone attention modules (sibling of `_set_cp_group_on_attention`), and arms a consumption guard.

No behavior change when no spec is active (the default): `_magi_attn_spec` is unset, so the forward is identical to before.

Enables / relationship to #2564
This is the generic integration piece extracted so that the cp=1 prefix-tree rollout feature (#2564) works through the HF path. With this landed, #2564 stamps its prefix-tree spec via `_set_attn_spec_on_attention` and uses `attn_implementation=magi`.

Verification
- `_set_attn_spec_on_attention` stamps only attention modules; the guard raises on an unconsumed spec and passes once consumed.
- … cosine_sim≈0.99999827).

Scope
cp_size==1 (prefix-tree / arbitrary masks); the GPU forward is `# pragma: no cover` (exercised by the FFA parity tests).

🤖 Generated with Claude Code